{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import os\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "import re\n",
    "import collections\n",
    "import random\n",
    "import pickle\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "maxlen = 20\n",
    "location = os.getcwd()\n",
    "num_layers = 3\n",
    "size_layer = 256\n",
    "learning_rate = 0.0001\n",
    "batch = 100\n",
    "batch_vector = 64\n",
    "filter_sizes = [2,3,4,5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with open('dataset-emotion.p', 'rb') as fopen:\n",
    "    df = pickle.load(fopen)\n",
    "with open('vector-emotion.p', 'rb') as fopen:\n",
    "    vectors = pickle.load(fopen)\n",
    "with open('dataset-dictionary.p', 'rb') as fopen:\n",
    "    dictionary = pickle.load(fopen)\n",
    "label = np.unique(df[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.cross_validation import train_test_split\n",
    "train_X, test_X, train_Y, test_Y = train_test_split(df[:, 0], df[:, 1], test_size = 0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    def __init__(self, sequence_length, dimension_input, dimension_output, \n",
    "                 learning_rate, filter_sizes, out_dimension):\n",
    "        self.X = tf.placeholder(tf.float32, shape=[None, sequence_length, dimension_input, 1])\n",
    "        self.Y = tf.placeholder(tf.float32, shape=[None, dimension_output])\n",
    "        pooled_outputs = []\n",
    "        for i in filter_sizes:\n",
    "            w = tf.Variable(tf.truncated_normal([i, dimension_input, 1, out_dimension], stddev=0.1))\n",
    "            b = tf.Variable(tf.truncated_normal([out_dimension], stddev = 0.01))\n",
    "            conv = tf.nn.relu(tf.nn.conv2d(self.X, w, strides=[1, 1, 1, 1],padding=\"VALID\") + b)\n",
    "            pooled = tf.nn.max_pool(conv,ksize=[1, sequence_length - i + 1, 1, 1],strides=[1, 1, 1, 1],padding='VALID')\n",
    "            pooled_outputs.append(pooled)\n",
    "        h_pool = tf.concat(pooled_outputs, 3)\n",
    "        h_pool_flat = tf.nn.dropout(tf.reshape(h_pool, [-1, out_dimension * len(filter_sizes)]), 0.1)\n",
    "        w = tf.Variable(tf.truncated_normal([out_dimension * len(filter_sizes), dimension_output], stddev=0.1))\n",
    "        b = tf.Variable(tf.truncated_normal([dimension_output], stddev = 0.01))\n",
    "        self.logits = tf.matmul(h_pool_flat, w) + b\n",
    "        self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = self.logits, labels = self.Y))\n",
    "        l2 = sum(0.0005 * tf.nn.l2_loss(tf_var) for tf_var in tf.trainable_variables())\n",
    "        self.cost += l2\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)\n",
    "        self.correct_pred = tf.equal(tf.argmax(self.logits, 1), tf.argmax(self.Y, 1))\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0 , pass acc: 0 , current acc: 0.542124828585\n",
      "epoch: 1 , training loss: 3.35311372858 , training acc: 0.419598065902 , valid loss: 2.68321076962 , valid acc: 0.542124828585\n",
      "epoch: 1 , pass acc: 0.542124828585 , current acc: 0.711548598374\n",
      "epoch: 2 , training loss: 2.1633080282 , training acc: 0.636775623075 , valid loss: 1.72472805264 , valid acc: 0.711548598374\n",
      "epoch: 2 , pass acc: 0.711548598374 , current acc: 0.806350524257\n",
      "epoch: 3 , training loss: 1.4103115939 , training acc: 0.768092363662 , valid loss: 1.16824797794 , valid acc: 0.806350524257\n",
      "epoch: 3 , pass acc: 0.806350524257 , current acc: 0.849987989857\n",
      "epoch: 4 , training loss: 1.00042384904 , training acc: 0.835371913903 , valid loss: 0.872913800082 , valid acc: 0.849987989857\n",
      "epoch: 4 , pass acc: 0.849987989857 , current acc: 0.866938777116\n",
      "epoch: 5 , training loss: 0.779082060772 , training acc: 0.863641271512 , valid loss: 0.710237429351 , valid acc: 0.866938777116\n",
      "epoch: 5 , pass acc: 0.866938777116 , current acc: 0.877082837157\n",
      "epoch: 6 , training loss: 0.651938589099 , training acc: 0.876136776496 , valid loss: 0.612761208097 , valid acc: 0.877082837157\n",
      "epoch: 6 , pass acc: 0.877082837157 , current acc: 0.881800725609\n",
      "epoch: 7 , training loss: 0.575408671483 , training acc: 0.884025201622 , valid loss: 0.554598746036 , valid acc: 0.881800725609\n",
      "epoch: 7 , pass acc: 0.881800725609 , current acc: 0.885978399992\n",
      "epoch: 8 , training loss: 0.526045780403 , training acc: 0.888395328309 , valid loss: 0.513354932274 , valid acc: 0.885978399992\n",
      "epoch: 8 , pass acc: 0.885978399992 , current acc: 0.889843946745\n",
      "epoch: 9 , training loss: 0.494372778566 , training acc: 0.890605886968 , valid loss: 0.488349257004 , valid acc: 0.889843946745\n",
      "epoch: 10 , training loss: 0.47289045914 , training acc: 0.893149380188 , valid loss: 0.469870204184 , valid acc: 0.88929172345\n",
      "epoch: 10 , pass acc: 0.889843946745 , current acc: 0.890708293377\n",
      "epoch: 11 , training loss: 0.456699455551 , training acc: 0.894922026597 , valid loss: 0.457595595793 , valid acc: 0.890708293377\n",
      "epoch: 11 , pass acc: 0.890708293377 , current acc: 0.892460990782\n",
      "epoch: 12 , training loss: 0.446638513537 , training acc: 0.895899830443 , valid loss: 0.449041553083 , valid acc: 0.892460990782\n",
      "epoch: 12 , pass acc: 0.892460990782 , current acc: 0.89330132912\n",
      "epoch: 13 , training loss: 0.438888216169 , training acc: 0.896517707566 , valid loss: 0.441819753037 , valid acc: 0.89330132912\n",
      "epoch: 14 , training loss: 0.432883783031 , training acc: 0.897060598881 , valid loss: 0.437225613595 , valid acc: 0.893145267631\n",
      "epoch: 14 , pass acc: 0.89330132912 , current acc: 0.89490997362\n",
      "epoch: 15 , training loss: 0.428116644276 , training acc: 0.897390532758 , valid loss: 0.433677574881 , valid acc: 0.89490997362\n",
      "epoch: 16 , training loss: 0.42583482991 , training acc: 0.897828446045 , valid loss: 0.431342460457 , valid acc: 0.893781520811\n",
      "epoch: 16 , pass acc: 0.89490997362 , current acc: 0.895102050339\n",
      "epoch: 17 , training loss: 0.422297930827 , training acc: 0.898533303418 , valid loss: 0.42704702497 , valid acc: 0.895102050339\n",
      "epoch: 17 , pass acc: 0.895102050339 , current acc: 0.89647059724\n",
      "epoch: 18 , training loss: 0.420973101064 , training acc: 0.898458319059 , valid loss: 0.425223151998 , valid acc: 0.89647059724\n",
      "epoch: 19 , training loss: 0.418451083675 , training acc: 0.899013207712 , valid loss: 0.426160294934 , valid acc: 0.895042024812\n",
      "epoch: 20 , training loss: 0.416928572798 , training acc: 0.899022206351 , valid loss: 0.422884839816 , valid acc: 0.89621849664\n",
      "epoch: 21 , training loss: 0.415623114005 , training acc: 0.899139182451 , valid loss: 0.422401906741 , valid acc: 0.895294124697\n",
      "epoch: 22 , training loss: 0.415033490208 , training acc: 0.899511108498 , valid loss: 0.420268938321 , valid acc: 0.896062433791\n",
      "epoch: 23 , training loss: 0.413098079238 , training acc: 0.899892032576 , valid loss: 0.421312907807 , valid acc: 0.895558230516\n",
      "epoch: 24 , training loss: 0.412755277649 , training acc: 0.899055200722 , valid loss: 0.418815023759 , valid acc: 0.894777919279\n",
      "epoch: 25 , training loss: 0.412281025758 , training acc: 0.899970016708 , valid loss: 0.418546486194 , valid acc: 0.89481393365\n",
      "epoch: 26 , training loss: 0.412267824375 , training acc: 0.900095992429 , valid loss: 0.418234579918 , valid acc: 0.895642266769\n",
      "epoch: 26 , pass acc: 0.89647059724 , current acc: 0.896566637138\n",
      "epoch: 27 , training loss: 0.411158640351 , training acc: 0.899586093858 , valid loss: 0.417241873468 , valid acc: 0.896566637138\n",
      "epoch: 27 , pass acc: 0.896566637138 , current acc: 0.896914774773\n",
      "epoch: 28 , training loss: 0.410558740608 , training acc: 0.900569897697 , valid loss: 0.417743579287 , valid acc: 0.896914774773\n",
      "epoch: 29 , training loss: 0.410470487949 , training acc: 0.899745061913 , valid loss: 0.41812677855 , valid acc: 0.896110453025\n",
      "epoch: 30 , training loss: 0.409928769267 , training acc: 0.900407930746 , valid loss: 0.416308329171 , valid acc: 0.89539016667\n",
      "epoch: 31 , training loss: 0.409347614496 , training acc: 0.900212969191 , valid loss: 0.41667915285 , valid acc: 0.895930381024\n",
      "epoch: 31 , pass acc: 0.896914774773 , current acc: 0.897034823251\n",
      "epoch: 32 , training loss: 0.409543498399 , training acc: 0.899976015699 , valid loss: 0.412999733549 , valid acc: 0.897034823251\n",
      "epoch: 33 , training loss: 0.40863363957 , training acc: 0.900104991175 , valid loss: 0.416355710678 , valid acc: 0.895654269842\n",
      "epoch: 33 , pass acc: 0.897034823251 , current acc: 0.897046830188\n",
      "epoch: 34 , training loss: 0.408432540775 , training acc: 0.900386934929 , valid loss: 0.414057912661 , valid acc: 0.897046830188\n",
      "epoch: 35 , training loss: 0.408452539572 , training acc: 0.900233964542 , valid loss: 0.415226170746 , valid acc: 0.896506610252\n",
      "epoch: 36 , training loss: 0.407732585473 , training acc: 0.900365938022 , valid loss: 0.413706965092 , valid acc: 0.896218497999\n",
      "epoch: 36 , pass acc: 0.897046830188 , current acc: 0.897683084226\n",
      "epoch: 37 , training loss: 0.407665638677 , training acc: 0.900467917529 , valid loss: 0.413783486472 , valid acc: 0.897683084226\n",
      "epoch: 38 , training loss: 0.407763669215 , training acc: 0.899955021831 , valid loss: 0.41386583129 , valid acc: 0.89703482368\n",
      "epoch: 39 , training loss: 0.407630774652 , training acc: 0.900668877967 , valid loss: 0.414629629251 , valid acc: 0.896134461747\n",
      "epoch: 40 , training loss: 0.407213048783 , training acc: 0.900536904291 , valid loss: 0.414391046705 , valid acc: 0.896494605461\n",
      "epoch: 41 , training loss: 0.406902315324 , training acc: 0.900734865726 , valid loss: 0.412316621423 , valid acc: 0.897154871012\n",
      "epoch: 42 , training loss: 0.406912047615 , training acc: 0.900752860805 , valid loss: 0.413488631704 , valid acc: 0.896218496926\n",
      "epoch: 43 , training loss: 0.406759121932 , training acc: 0.900395931816 , valid loss: 0.414951637369 , valid acc: 0.896170477836\n",
      "epoch: 44 , training loss: 0.406277312769 , training acc: 0.900743863006 , valid loss: 0.412557104198 , valid acc: 0.896278517158\n",
      "epoch: 45 , training loss: 0.406819455151 , training acc: 0.900872836373 , valid loss: 0.413255342881 , valid acc: 0.896866754467\n",
      "epoch: 46 , training loss: 0.406384702676 , training acc: 0.900506910515 , valid loss: 0.412755171971 , valid acc: 0.896986806521\n",
      "epoch: 47 , training loss: 0.406851520408 , training acc: 0.901097791847 , valid loss: 0.414064356509 , valid acc: 0.895942385028\n",
      "break epoch: 47\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "dimension = vectors.shape[1]\n",
    "model = Model(maxlen, dimension, len(label), learning_rate, filter_sizes, size_layer)\n",
    "sess.run(tf.global_variables_initializer())\n",
    "saver = tf.train.Saver(tf.global_variables())\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 10, 0, 0, 0\n",
    "while True:\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
    "        print('break epoch:', EPOCH)\n",
    "        break\n",
    "    train_acc, train_loss, test_acc, test_loss = 0, 0, 0, 0\n",
    "    for i in range(0, (train_X.shape[0] // batch) * batch, batch):\n",
    "        batch_x = np.zeros((batch, maxlen, dimension))\n",
    "        batch_y = np.zeros((batch, len(label)))\n",
    "        for k in range(batch):\n",
    "            tokens = train_X[i + k].split()[:maxlen]\n",
    "            for no, text in enumerate(tokens[::-1]):\n",
    "                try:\n",
    "                    batch_x[k, -1 - no, :] += vectors[dictionary[text], :]\n",
    "                except:\n",
    "                    continue\n",
    "            batch_y[k, int(train_Y[i + k])] = 1.0\n",
    "        batch_x = np.expand_dims(batch_x, axis=-1)\n",
    "        loss, _ = sess.run([model.cost, model.optimizer], feed_dict = {model.X : batch_x, model.Y : batch_y})\n",
    "        train_loss += loss\n",
    "        train_acc += sess.run(model.accuracy, feed_dict = {model.X : batch_x, model.Y : batch_y})\n",
    "    \n",
    "    for i in range(0, (test_X.shape[0] // batch) * batch, batch):\n",
    "        batch_x = np.zeros((batch, maxlen, dimension))\n",
    "        batch_y = np.zeros((batch, len(label)))\n",
    "        for k in range(batch):\n",
    "            tokens = test_X[i + k].split()[:maxlen]\n",
    "            for no, text in enumerate(tokens[::-1]):\n",
    "                try:\n",
    "                    batch_x[k, -1 - no, :] += vectors[dictionary[text], :]\n",
    "                except:\n",
    "                    continue\n",
    "            batch_y[k, int(test_Y[i + k])] = 1.0\n",
    "        batch_x = np.expand_dims(batch_x, axis=-1)\n",
    "        loss, acc = sess.run([model.cost, model.accuracy], feed_dict = {model.X : batch_x, model.Y : batch_y})\n",
    "        test_loss += loss\n",
    "        test_acc += acc\n",
    "        \n",
    "    train_loss /= (train_X.shape[0] // batch)\n",
    "    train_acc /= (train_X.shape[0] // batch)\n",
    "    test_loss /= (test_X.shape[0] // batch)\n",
    "    test_acc /= (test_X.shape[0] // batch)\n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print('epoch:', EPOCH, ', pass acc:', CURRENT_ACC, ', current acc:', test_acc)\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "        saver.save(sess, os.getcwd() + \"/model-cnn-vector.ckpt\")\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "    EPOCH += 1\n",
    "    print('epoch:', EPOCH, ', training loss:', train_loss, ', training acc:', train_acc, ', valid loss:', test_loss, ', valid acc:', test_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
